import carla
from policies import BasePolicy
from srunner.scenariomanager.carla_data_provider import CarlaDataProvider

# Module-level shorthand aliases for the carla.TrafficLightState members.
RED, YELLOW, GREEN = (
    carla.TrafficLightState.Red,
    carla.TrafficLightState.Yellow,
    carla.TrafficLightState.Green,
)

class HeuristicCommPolicy(BasePolicy):
    """Rule-based communication policy that reports apparent red-light runners.

    Each call to :meth:`act` scans every vehicle actor in the current CARLA
    world and returns ``{"message": ...}``. The message lists the ids of
    vehicles that appear to be running a red light, or is
    ``"The traffic situation is safe."`` when none are found.
    """

    # Magnitude of ``get_velocity()`` (presumably m/s) a vehicle must exceed to
    # count as moving. 0.0 flags any nonzero velocity; a subclass may raise it
    # to ignore physics jitter on stopped vehicles.
    MIN_MOVING_SPEED = 0.0

    def __init__(self, agent_id):
        super().__init__(agent_id)
        self.step_count = 0
        self._map = None
        self._world = None
        # vehicle id -> last traffic-light state read while the vehicle was at
        # a traffic light. The entry is dropped once the vehicle leaves a
        # junction, or once it no longer appears in the vehicle list.
        self._traffic_light_status = {}
        # Ids of vehicles that were on a junction waypoint at the last check;
        # used to detect when a vehicle exits a junction.
        self._vehicles_in_junction = set()

    def act(self):
        """Advance one step and return this step's communication action.

        Returns:
            dict: ``{"message": str}`` holding the red-light report, or
            ``"The traffic situation is safe."`` if no violation was found.
        """
        self.step_count += 1
        world = CarlaDataProvider.get_world()
        if self._world is None:
            # The world handle and map are cached on the first step only.
            self._world = world
            self._map = world.get_map()
        all_vehicles = world.get_actors().filter("*vehicle*")

        # Report abnormal behaviour if any was found; otherwise report that
        # everything is fine.
        violations = self.check_red_light_violation(all_vehicles)
        message = violations if violations else "The traffic situation is safe."
        return {"message": message}

    def check_red_light_violation(self, all_vehicles):
        """Return a report naming vehicles that appear to be running a red light.

        A vehicle is flagged when all of the following hold:

        * the last light state it reported while at a traffic light is ``RED``;
        * it is currently on a junction waypoint;
        * its speed exceeds ``MIN_MOVING_SPEED``.

        The remembered light state is forgotten when the vehicle leaves a
        junction. A red light seen on the approach to one junction is
        therefore not applied to later junctions the vehicle crosses.

        Args:
            all_vehicles: Iterable of every vehicle actor currently in the
                world. Bookkeeping for vehicle ids not present here is
                discarded.

        Returns:
            str: One ``"Vehicle <id> is running red light. "`` sentence per
            flagged vehicle, concatenated. Empty when nothing is flagged.
        """
        reports = []
        seen_ids = set()
        for vehicle in all_vehicles:
            vehicle_id = vehicle.id
            seen_ids.add(vehicle_id)
            in_junction = self._map.get_waypoint(vehicle.get_location()).is_junction

            if in_junction:
                self._vehicles_in_junction.add(vehicle_id)
            elif vehicle_id in self._vehicles_in_junction:
                # The vehicle just left a junction, so the light it saw on the
                # approach no longer applies. Clear the state before reading
                # any new light, so a light reached right after the exit is
                # still captured below.
                self._vehicles_in_junction.discard(vehicle_id)
                self._traffic_light_status.pop(vehicle_id, None)

            if vehicle.is_at_traffic_light():
                self._traffic_light_status[vehicle_id] = vehicle.get_traffic_light_state()

            if (
                in_junction
                and self._traffic_light_status.get(vehicle_id) == RED
                and vehicle.get_velocity().length() > self.MIN_MOVING_SPEED
            ):
                reports.append(f"Vehicle {vehicle_id} is running red light. ")

        # Drop state for vehicles that are no longer present (e.g. destroyed
        # actors), so the per-vehicle bookkeeping cannot grow without bound.
        for stale_id in set(self._traffic_light_status) - seen_ids:
            del self._traffic_light_status[stale_id]
        self._vehicles_in_junction &= seen_ids
        return "".join(reports)